close all
clear

% --- Analysis settings ---------------------------------------------------
folderName = '2023_01_13_4_2D2_qw_n2_6Er_modfreq3_1010_phase0';
prefix = 'scan';             % file-name prefix of the per-shot *atomMatrix.mat files
paramVar ='evolution_dur';   % scanned parameter used as the time axis

plot_figure = 1;             % 1 = draw figures
export_data = 1;             % 1 = save results into mat_files
stat_phase_pi = 0;           % phase in units of pi, shown in the figure name

% Region of interest: each column col1:col2 is analyzed separately along the
% sites row1:row2.
col1 = 115;
col2 = 128;

row1 = 90;  % number of sites (row2-row1+1) must be even so there are two center sites
row2 = 113; % starting site must be in the center
centersites = [(row2-row1+1)/2 (row2-row1+1)/2+1];
rlen = row2-row1+1;

% Scan parameters of the batch. NOTE: param_table has an extra trailing zero
% entry (cause unknown); the extraction step excludes the last entry.
[param_names, param_table] = get_batch_params(folderName);
paramInd = find(strcmp(param_names, paramVar));
if numel(paramInd) ~= 1
    % Fail loudly instead of silently analysing an empty/ambiguous column.
    error('Expected exactly one scan parameter named "%s" in %s, found %d.', ...
        paramVar, folderName, numel(paramInd));
end
alltvalues = param_table(:, paramInd);
alltlength = length(alltvalues);
atomMatrix = zeros(248,248); % placeholder; replaced by every loaded shot
uniquetvals = unique(alltvalues);
tlen = numel(uniquetvals);

doublonInd = find(strcmp(param_names, 'doublon_ramp_dur'));
if numel(doublonInd) ~= 1
    error('Expected exactly one scan parameter named "%s" in %s, found %d.', ...
        'doublon_ramp_dur', folderName, numel(doublonInd));
end
doublon_ramp_dur_all = param_table(:, doublonInd);
doublon_ramp_dur_list = unique(doublon_ramp_dur_all);


%% Extract data

% Number of scan points per doublon_ramp_dur value. The last entry of
% param_table is treated as spurious and excluded from the count.
nDoublon = numel(doublon_ramp_dur_list);
tlen_aux = (numel(alltvalues) - 1) / nDoublon;
if ~isfinite(tlen_aux) || tlen_aux < 0 || tlen_aux ~= round(tlen_aux)
    error('Scan length %d is not a multiple of the %d doublon_ramp_dur values.', ...
        numel(alltvalues) - 1, nDoublon);
end

% The combination step reads the "_a" data (no doublon ramp) from dd = 1 and
% the "_b" data (with doublon ramp) from dd = 2.
if nDoublon ~= 2 || doublon_ramp_dur_list(1) ~= 0
    warning('Expected doublon_ramp_dur values [0, >0], got %s.', ...
        mat2str(doublon_ramp_dur_list(:)'));
end

% Pair classification of every post-selected column (exactly two atoms):
%   case 1: difference between atom positions = 0 (doublon)
%   case 2: difference between atom positions = 1
%   case 3: difference between atom positions > 1
% Npost(jj, case, dd) counts pairs; bo(jj, site, case, dd) sums occupations.
% corr_*(site, site, jj) are pair histograms: _a filled when
% doublon_ramp_dur <= 0, _b when doublon_ramp_dur > 0.
Npost = zeros(tlen_aux, 3, 2);
Nrealizations = zeros(tlen_aux, 2);
bo = zeros(tlen_aux, rlen, 3, 2);
bo_density = zeros(tlen_aux, rlen, 3, 2);
corr_diff1_a = zeros(rlen, rlen, tlen_aux);
corr_diff2_a = zeros(rlen, rlen, tlen_aux);
corr_diff0_b = zeros(rlen, rlen, tlen_aux);
corr_diff2_b = zeros(rlen, rlen, tlen_aux);

for dd = 1:nDoublon
    doublon_ramp_dur = doublon_ramp_dur_list(dd);
    start = 1 + tlen_aux*(dd-1); % first scan index belonging to this dd

    for jj = 1:tlen_aux
        scanIdx = start + jj - 1;
        listing = dir(fullfile(folderName, [prefix '*' num2str(scanIdx,'%03.f') 'atomMatrix.mat']));
        Nshots = size(listing,1);
        Nrealizations(jj, dd) = (col2-col1+1) * Nshots; % every column counts as one realization

        for i = 1:Nshots
            % Load only atomMatrix, into a struct, so other variables stored
            % in the shot file cannot overwrite variables of this script.
            shot = load(fullfile(folderName, listing(i).name), 'atomMatrix');
            atomMatrix = shot.atomMatrix;

            for col = col1:col2
                occ = atomMatrix(row1:row2, col);
                if sum(occ) ~= 2
                    continue % post-selection: keep only columns with exactly two atoms
                end
                row = find(occ); % site indices (1..rlen) of the occupied sites

                if doublon_ramp_dur > 0
                    if diff(row) > 1
                        Npost(jj,3,dd) = Npost(jj,3,dd) + 1;
                        bo(jj,:,3,dd) = bo(jj,:,3,dd) + occ';
                        corr_diff2_b(row(1), row(2), jj) = corr_diff2_b(row(1), row(2), jj) + 1;
                        corr_diff2_b(row(2), row(1), jj) = corr_diff2_b(row(2), row(1), jj) + 1;
                    else
                        % Adjacent pair is recorded as a doublon (two atoms)
                        % on the second site.
                        % NOTE(review): the original comment here was truncated
                        % ("quic..."); confirm the choice of row(2).
                        Npost(jj,1,dd) = Npost(jj,1,dd) + 1;
                        doublon_pos = row(2);
                        atomRow = zeros(1, rlen);
                        atomRow(doublon_pos) = 2;
                        bo(jj,:,1,dd) = bo(jj,:,1,dd) + atomRow;
                        corr_diff0_b(doublon_pos, doublon_pos, jj) = corr_diff0_b(doublon_pos, doublon_pos, jj) + 2;
                    end
                else
                    if diff(row) > 1
                        Npost(jj,3,dd) = Npost(jj,3,dd) + 1;
                        bo(jj,:,3,dd) = bo(jj,:,3,dd) + occ';
                        corr_diff2_a(row(1), row(2), jj) = corr_diff2_a(row(1), row(2), jj) + 1;
                        corr_diff2_a(row(2), row(1), jj) = corr_diff2_a(row(2), row(1), jj) + 1;
                    else
                        Npost(jj,2,dd) = Npost(jj,2,dd) + 1;
                        bo(jj,:,2,dd) = bo(jj,:,2,dd) + occ';
                        corr_diff1_a(row(1), row(2), jj) = corr_diff1_a(row(1), row(2), jj) + 1;
                        corr_diff1_a(row(2), row(1), jj) = corr_diff1_a(row(2), row(1), jj) + 1;
                    end
                end
            end
        end
    end

    % Normalize occupations by the number of post-selected pairs per case.
    for nn = 1:3
        bo_density(:,:,nn,dd) = bo(:,:,nn,dd) ./ Npost(:,nn,dd);
    end
end

% Normalize the pair histograms once, after all shots are accumulated; the
% counts are moved to the third (scan-point) dimension to broadcast.
corr_diff1_a_density = corr_diff1_a ./ reshape(Npost(:,2,1), 1, 1, []);
corr_diff2_a_density = corr_diff2_a ./ reshape(Npost(:,3,1), 1, 1, []);
corr_diff0_b_density = corr_diff0_b ./ reshape(Npost(:,1,2), 1, 1, []);
corr_diff2_b_density = corr_diff2_b ./ reshape(Npost(:,3,2), 1, 1, []);

% Scan points/cases without any post-selected pair give 0/0; report 0.
bo_density(isnan(bo_density)) = 0;
corr_diff1_a_density(isnan(corr_diff1_a_density)) = 0;
corr_diff2_a_density(isnan(corr_diff2_a_density)) = 0;
corr_diff0_b_density(isnan(corr_diff0_b_density)) = 0;
corr_diff2_b_density(isnan(corr_diff2_b_density)) = 0;

Nrealizations_aux = sum(Nrealizations, 2);


%% Combine data

% Post-selected pair counts per scan point. Suffix _a = run with
% doublon_ramp_dur == 0 (dd = 1), _b = run with doublon ramp (dd = 2).
Ndiff1_a = Npost(:,2,1);
Ndiff2_a = Npost(:,3,1);
Ndiff0_b = Npost(:,1,2);
Ndiff2_b = Npost(:,3,2);

% Adjacent pairs (a) and doublons (b) enter with weight 2.
% NOTE(review): the reason for the factor 2 is not stated in this script - confirm.
Ndiff1_eff = 2 * Ndiff1_a;
Ndiff0_eff = 2 * Ndiff0_b;

bo_diff2_a = bo_density(:,:,3,1);
bo_diff2_b = bo_density(:,:,3,2);
bo_diff1_a = bo_density(:,:,2,1);
bo_diff0_b = bo_density(:,:,1,2);

% Count-weighted average of the per-case densities; the count column vectors
% expand along the site dimension.
Npost_total = Ndiff2_a + Ndiff2_b + Ndiff1_eff + Ndiff0_eff;
bo_density_weighted = (Ndiff2_a .* bo_diff2_a + Ndiff2_b .* bo_diff2_b ...
    + Ndiff0_eff .* bo_diff0_b + Ndiff1_eff .* bo_diff1_a) ./ Npost_total;

% Same weighting for the pair correlations; counts are moved to the third
% (scan-point) dimension so they broadcast over the rlen x rlen maps.
toPage = @(v) reshape(v, 1, 1, []);
corr_weighted = (toPage(Ndiff2_a) .* corr_diff2_a_density ...
    + toPage(Ndiff2_b) .* corr_diff2_b_density ...
    + toPage(Ndiff0_eff) .* corr_diff0_b_density ...
    + toPage(Ndiff1_eff) .* corr_diff1_a_density) ./ toPage(Npost_total);

% Compress data (several scan points per batch may share the same time).
% Times are taken from the first run; the second run is stored at the same
% index jj and is assumed to use the same ordering.
tvalues = alltvalues(1:size(bo_density_weighted,1));

% Build the time axis from the analysed scan points only, so the spurious
% trailing param_table entry cannot create an empty (NaN) time bin.
uniquetvals = unique(tvalues);
tlen = numel(uniquetvals);

Nrealizations_compressed = zeros(tlen, 1);
Npost_compressed = zeros(tlen, 1);
bo_compressed = zeros(tlen, size(bo,2));
corr_compressed = zeros(rlen, rlen, tlen);
for tt = 1:tlen
    sel = tvalues == uniquetvals(tt);
    bo_compressed(tt,:) = mean(bo_density_weighted(sel,:), 1);
    corr_compressed(:,:,tt) = mean(corr_weighted(:,:,sel), 3);
    Npost_compressed(tt) = sum(Npost_total(sel));
    Nrealizations_compressed(tt) = sum(Nrealizations_aux(sel));
end

% Post-selection rate and its binomial standard error.
% NOTE(review): Npost_total includes the weight-2 counts above, so this is not
% strictly a fraction of realizations - confirm the intended normalization.
P_post_compressed = Npost_compressed ./ Nrealizations_compressed;
P_post_err = sqrt(P_post_compressed .* (1 - P_post_compressed) ./ Nrealizations_compressed);


%% convert time to units of tau

% Time unit tau = 1/(2*pi*J), expressed in ms (factor 1000).
% NOTE(review): J is presumably the tunneling rate in Hz - confirm units.
J = 10.6;
tau_ms = 1000 * (1/(2*pi*J));
tau_data_list = uniquetvals ./ tau_ms; % scan times in units of tau


%% Plot data

if plot_figure
    % Site axis labels derived from the window size: rlen/2 down to 1-rlen/2
    % (12 ... -11 for the 24-site window row1:row2).
    siteLim = [rlen/2, 1 - rlen/2];
    fig1 = figure('Name', ['Two-particle quantum walk, phase=', num2str(stat_phase_pi), 'pi']);
    % imagesc places rows using only the first and last entries of
    % uniquetvals, so the time axis is exact only for uniform time steps.
    imagesc(siteLim, uniquetvals, bo_compressed)
    colorbar
    clim([0 1.1])
    xlabel('Row')
    ylabel('Time (ms)')
end


%% Post selection rate

% Post-selection rate versus evolution time, with its standard error bars.
if plot_figure
    fig3 = figure();
    ax = axes(fig3);
    hold(ax, 'on')
    errorbar(ax, uniquetvals, P_post_compressed, P_post_err, 'o', 'Linewidth', 1.5)
    xlabel(ax, 'Time (ms)')
    ylabel(ax, 'Post-selection rate')
    ylim(ax, [0 1])
    hold(ax, 'off')
end


%% Export data

if export_data
    % fullfile picks the correct path separator on every platform; the
    % output folder is created on first use so save() does not fail.
    outDir = 'mat_files';
    if ~exist(outDir, 'dir')
        mkdir(outDir);
    end
    save(fullfile(outDir, ['m' folderName '.mat']), ...
        'Npost', 'Nrealizations','bo_compressed', 'corr_compressed', 'uniquetvals');
end


%% Plot correlation at specific time

% Select the time bin closest to t_plot. Exact floating-point equality could
% match nothing and silently produce an empty plot.
t_plot = 36.25; % ms, same units as uniquetvals
[~, corr_tt] = min(abs(uniquetvals - t_plot));
if plot_figure && ~isempty(corr_tt)
    fig_corr = figure();
    % Site axis labels derived from the window size (12 ... -11 for 24 sites).
    siteLim = [rlen/2, 1 - rlen/2];
    imagesc(siteLim, siteLim, corr_compressed(:,:,corr_tt))
    colorbar
    clim([0 0.1])
    axis image
    xlabel('Row')
    ylabel('Row')
    set(gca,'YDir','normal')
    title(sprintf('t = %s ms', num2str(uniquetvals(corr_tt))))
end

